{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "66ee51fa-4f4a-48a9-b915-9319b337a99c",
   "metadata": {},
   "outputs": [],
   "source": [
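    "# Test script\n",
    "# 1. Feed the support-set images of each class through the feature-extraction network to get embeddings\n",
    "# 2. Average the embeddings of each class, then normalize the mean to get u1, u2, u3\n",
    "# 3. Feed the query through the pretrained network to get its feature vector, normalize it, then multiply it with the matrix M stacked from u1, u2, u3 to obtain the cosine similarities"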
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2add4b8-c71e-466a-a501-a160e5f98d55",
   "metadata": {
    "tags": []
   },
   "source": [
    "### 计算两个向量的余弦相似度\n",
    "$$cos(\\theta)=\\frac{ a \\cdot b}{||a||\\times||b||} $$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7570a9c1-eac9-489e-8792-432f06482542",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-08-25 23:15:46.234851: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0\n"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import os\n",
    "import random\n",
    "import glob\n",
    "\n",
    "# 这里image需要重命名一下\n",
    "from tensorflow.keras.preprocessing import image as keras_image\n",
    "from tensorflow.keras.applications.resnet import preprocess_input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7f7a44cf-0c72-415d-a101-b75478cf3905",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-08-25 23:15:47.264143: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcuda.so.1\n",
      "2022-08-25 23:15:47.305537: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1733] Found device 0 with properties: \n",
      "pciBusID: 0000:01:00.0 name: NVIDIA GeForce RTX 3090 computeCapability: 8.6\n",
      "coreClock: 1.695GHz coreCount: 82 deviceMemorySize: 23.70GiB deviceMemoryBandwidth: 871.81GiB/s\n",
      "2022-08-25 23:15:47.305568: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0\n",
      "2022-08-25 23:15:47.309129: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublas.so.11\n",
      "2022-08-25 23:15:47.309158: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublasLt.so.11\n",
      "2022-08-25 23:15:47.310329: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcufft.so.10\n",
      "2022-08-25 23:15:47.310601: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcurand.so.10\n",
      "2022-08-25 23:15:47.311440: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcusolver.so.11\n",
      "2022-08-25 23:15:47.312143: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcusparse.so.11\n",
      "2022-08-25 23:15:47.312274: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudnn.so.8\n",
      "2022-08-25 23:15:47.313834: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1871] Adding visible gpu devices: 0\n",
      "2022-08-25 23:15:47.314430: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2022-08-25 23:15:47.323879: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1733] Found device 0 with properties: \n",
      "pciBusID: 0000:01:00.0 name: NVIDIA GeForce RTX 3090 computeCapability: 8.6\n",
      "coreClock: 1.695GHz coreCount: 82 deviceMemorySize: 23.70GiB deviceMemoryBandwidth: 871.81GiB/s\n",
      "2022-08-25 23:15:47.326048: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1871] Adding visible gpu devices: 0\n",
      "2022-08-25 23:15:47.326155: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0\n",
      "2022-08-25 23:15:47.797712: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1258] Device interconnect StreamExecutor with strength 1 edge matrix:\n",
      "2022-08-25 23:15:47.797757: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1264]      0 \n",
      "2022-08-25 23:15:47.797765: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1277] 0:   N \n",
      "2022-08-25 23:15:47.800359: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1418] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 22318 MB memory) -> physical GPU (device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:01:00.0, compute capability: 8.6)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    }
   ],
   "source": [
    "# Load the trained embedding network for inference only.\n",
    "# compile=False: the saved file carries no training configuration, so there is\n",
    "# nothing to compile; being explicit also avoids the 'model was *not* compiled' warning.\n",
    "embedding = tf.keras.models.load_model('./latest_new.h5', compile=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b958bd89-45b4-4ac3-9e23-5d4ccd32c362",
   "metadata": {},
   "outputs": [],
   "source": [
    "# embedding.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e8085797-cba9-4354-9729-199405a7f770",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extractFeat(model, filename, target_size=(200, 200), eps=1e-12):\n",
    "    \"\"\"Return the L2-normalized embedding of one image file.\n",
    "\n",
    "    Args:\n",
    "        model: Feature extractor exposing ``predict`` (a Keras model).\n",
    "        filename: Path of the image to load.\n",
    "        target_size: Size the image is resized to on load. Defaults to\n",
    "            (200, 200), the value that used to be hard-coded here.\n",
    "        eps: Lower bound applied to the L2 norm so that an all-zero\n",
    "            embedding returns zeros instead of NaN.\n",
    "\n",
    "    Returns:\n",
    "        The first row of the model output divided by its L2 norm.\n",
    "    \"\"\"\n",
    "    # Same preprocessing as at training time (per the original author's note)\n",
    "    img = keras_image.load_img(filename, target_size=target_size)\n",
    "    img = keras_image.img_to_array(img)\n",
    "    img = preprocess_input(img)\n",
    "    # Add a leading batch axis: the model is fed a batch of one image\n",
    "    batch = np.expand_dims(img, axis=0)\n",
    "    # verbose=0 so no progress bar is printed per image, whatever the TF default is\n",
    "    feat = model.predict(batch, verbose=0)[0]\n",
    "    # L2 normalization; the norm is clamped to avoid 0/0 -> NaN\n",
    "    norm = max(float(np.linalg.norm(feat)), eps)\n",
    "    return feat / norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3a7d2ed4-2cc1-4fbb-a0df-69473e808cc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 获取测试的support set 和query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b85a3f21-82c4-4323-a03c-938d49b41f26",
   "metadata": {},
   "outputs": [],
   "source": [
    "# glob returns paths in filesystem-dependent order; sort first so that the\n",
    "# last five class folders are the same on every machine and every run.\n",
    "query_classes = sorted(glob.glob('dogImages/train/*'))[-5:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2bb33ffc-3a22-41fd-88bd-324d30f9537b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['dogImages/train/129.Tibetan_mastiff',\n",
       " 'dogImages/train/130.Welsh_springer_spaniel',\n",
       " 'dogImages/train/131.Wirehaired_pointing_griffon',\n",
       " 'dogImages/train/132.Xoloitzcuintli',\n",
       " 'dogImages/train/133.Yorkshire_terrier']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "query_classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "28647ca9-e247-4c08-afcf-39e5bce2f9eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One list of image paths per class, in the same order as query_classes.\n",
    "# Each list is sorted because glob order is filesystem-dependent and later cells\n",
    "# pick images by position (support images and the query image).\n",
    "query_files = [\n",
    "    sorted(glob.glob(class_item + '/*'))\n",
    "    for class_item in query_classes\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a007a1f1-1530-418e-9f52-36dca5290be0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# query_files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "b114a451-6099-4bc1-8fe2-e127905392cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 遍历每一类，选择其中前两个作为support set计算平均值，第三个当成查询用query\n",
    "class_u = []\n",
    "for class_images in query_files:\n",
    "    # 获取该类图片\n",
    "    feat_A = extractFeat(embedding,class_images[0]) \n",
    "    feat_B = extractFeat(embedding,class_images[1])     \n",
    "    # 平均值\n",
    "    feat_avg = (feat_A + feat_B) / 2\n",
    "    # 归一化\n",
    "    feat_norm = feat_avg / np.linalg.norm(feat_avg)\n",
    "    class_u.append(feat_norm)\n",
    "# 转为numpy数组    \n",
    "class_u = np.array(class_u)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "7984e55f-356a-480a-9a1f-cc707f5b8d54",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5, 256)"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class_u.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "188b30d0-23ce-4f55-bde6-94933d283ced",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.0778444 , -0.08980139,  0.01660734, ...,  0.07267957,\n",
       "         0.1185353 , -0.00660693],\n",
       "       [ 0.10270932, -0.06751491,  0.01926412, ..., -0.0031268 ,\n",
       "         0.09171963,  0.05606749],\n",
       "       [ 0.12412076, -0.12937087,  0.0158099 , ..., -0.01265056,\n",
       "         0.071591  ,  0.08709649],\n",
       "       [ 0.10841086, -0.12155013, -0.00542881, ..., -0.02353824,\n",
       "        -0.00940758,  0.08753707],\n",
       "       [ 0.10181352, -0.10390382,  0.0849184 , ..., -0.01457242,\n",
       "         0.06649275,  0.02868202]], dtype=float32)"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class_u"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "id": "662888b5-c9f4-4175-aaec-fabb7fcb967d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query image: fixed pick (class index 4, image index 5), not random; the support set uses images 0 and 1\n",
    "query_img = query_files[4][5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "id": "51ddda48-40a2-426f-8027-7ae621d48ef6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'dogImages/train/133.Yorkshire_terrier/Yorkshire_terrier_08319.jpg'"
      ]
     },
     "execution_count": 105,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "query_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "id": "bec30dc6-8724-42d0-a151-8a6787b544c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed the query image (extractFeat returns an L2-normalized vector)\n",
    "query_feat = extractFeat(model=embedding, filename=query_img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "id": "5cea0612-6759-4813-91e6-97b3556172ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# query_feat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "id": "aa377ca7-9dd4-4430-96ce-155e0404f1f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cosine similarity of the query against every class prototype:\n",
    "# a matrix-vector product, since all vectors involved are already normalized\n",
    "class_prob = class_u @ query_feat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "id": "0b77f2c9-53c2-4182-8f59-7075e7a72540",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.3200401 , 0.53333116, 0.90293527, 0.7195747 , 0.9551848 ],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 109,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class_prob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "id": "fd45bb0e-4d8a-45b5-b248-02120d3cdef4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'dogImages/train/133.Yorkshire_terrier'"
      ]
     },
     "execution_count": 110,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Look up the class folder with the highest similarity\n",
    "query_classes[class_prob.argmax()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc70eb2f-824d-4cae-8cb2-d40bc176edf7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "5470d929-6329-4fb5-8e45-7448103982a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 测试猫咪"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "d8d6ac12-a581-49aa-aaee-bb24fe41e18d",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = ['Abyssinian','Aegean','British_Shorthair','Donskoy','Persian']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "id": "16e00bf2-ca5b-43cb-b525-8c703364cb23",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one prototype vector per cat class from its first n_support images.\n",
    "n_support = 2\n",
    "class_u = []\n",
    "for label in labels:\n",
    "    # Sorted: glob order is filesystem-dependent. Sliced: only the images that\n",
    "    # enter the average are embedded, instead of every file in the folder.\n",
    "    files = sorted(glob.glob('./cat/' + label + '/*'))[:n_support]\n",
    "    if len(files) < n_support:\n",
    "        raise ValueError(f'need at least {n_support} images for {label}, found {len(files)}')\n",
    "    feats = np.stack([extractFeat(embedding, test_file) for test_file in files])\n",
    "    # Mean of the support embeddings\n",
    "    feat_avg = feats.mean(axis=0)\n",
    "    # Re-normalize the mean\n",
    "    feat_norm = feat_avg / np.linalg.norm(feat_avg)\n",
    "    class_u.append(feat_norm)\n",
    "\n",
    "class_u = np.array(class_u)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "id": "8e4d9e50-2e19-468d-afe2-aab9bd846ac1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5, 256)"
      ]
     },
     "execution_count": 124,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class_u.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "id": "6da18a76-123f-46fa-8c9b-2f541a056e4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sorted because later cells index this list by position and glob order\n",
    "# depends on the filesystem.\n",
    "query_list = sorted(glob.glob('./cat/*_query*'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "id": "84e9aa5e-e76a-47c3-8ddf-4e6f891e1835",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['./cat/Abyssinian_query.jpg',\n",
       " './cat/Aegean_query.jpg',\n",
       " './cat/British_Shorthair_query.jpeg',\n",
       " './cat/Donskoy_query.jpeg',\n",
       " './cat/Persian_query.jpeg']"
      ]
     },
     "execution_count": 126,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "query_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "id": "bca38c5c-364c-45ae-a5e0-7038e779c813",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed the query image at position 3 of query_list\n",
    "query_feat = extractFeat(model=embedding, filename=query_list[3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "id": "6fb3bafe-01c2-4468-a172-b0a4d73fe4ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cosine similarity of the query against every cat-class prototype\n",
    "class_prob = class_u @ query_feat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "id": "9c36aec8-1901-44d3-9b01-1be8ff1bdb6a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.83904886, 0.843537  , 0.8892164 , 0.81702244, 0.68522656],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 141,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class_prob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "id": "11a59e4e-22ba-47bb-867e-e2f4ce640889",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'British_Shorthair'"
      ]
     },
     "execution_count": 142,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Label of the most similar prototype\n",
    "labels[class_prob.argmax()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcb0f3dd-5189-4cff-9311-10a6a353479c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2156c914-4f7d-463f-9544-fa60442c23c2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43c53f93-3c61-4a47-814e-7f83d024f3d2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d33e10c5-6410-459a-ab90-f575259c3d0c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2851bfa-0b7f-4e4d-b589-6828a7762be5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3107b27-74c7-4d09-9ef9-67054e48a5e4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
